Skip to content

Conversation

@jessicaliu06
Copy link
Collaborator

supertensor.py

  • Represents a tensor using the following attributes:
    • a base tensor
    • a map that encodes which modes of the logical tensor are combined into each mode of the base tensor, and in what order these logical modes are combined.
  • Supports item access using logical coordinates.

interpreter.py

  • Input: an einsum AST where the tensor name of each Access node is bound to a SuperTensor, and the indices of each Access node are logical indices on the SuperTensor
    • Note: The base tensor of a SuperTensor in the input AST will have some arbitrary shape.
  • Algorithm:
    • Collect all Access nodes in the einsum expression.
    • Reshape the base tensors to combine index groups in a useful way.
    • Bind each tensor name to the correctly-shaped base tensor.
    • Rewrite each Access node with the proper indices to access the correctly-shaped base tensor.
    • Call the standard einsum interpreter on the rewritten AST to compute the base tensor of the output tensor.
    • Wrap the output base tensor into an output SuperTensor.

test.py

  • Contains a simple test case, for which the output of the SuperTensor interpreter matches the output of the regular interpreter.

@jessicaliu06 jessicaliu06 self-assigned this Nov 11, 2025
Copy link
Member

@willow-ahrens willow-ahrens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like the correct logic! For readability and maintainability, consider using generator expressions, sets, and dicts more frequently. Take a look at my specific comments for more detail.

for child in curr.children:
postorder(child)

postorder(node)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use PostOrderDFS from symbolic.py, inline this function

Returns:
`List[Tuple[FrozenSet[str], List[str]]]`
A list of tuples, each containing a set of tensor names and the corresponding list of indices that appear in exactly those tensors.
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this function, would it be possible to instead just construct a dictionary mapping sets of tensors to lists of indices, then for each index compute the set of tensors and add that index to the corresponding list in the dictionary?

idx_groups = Dict[Index, Set[Alias]]()
for node in PostOrderDFS(einsum):
   match node:
      case Access(tns, idxs):
         for idx in idxs:
             idx_groups.setdefault(idx, Set[Alias]()).add(tns)
group_idxs = Dict[Tuple[Alias], Set[Index]]()
for idx, group in idx_groups:
    tns_groups[group].setdefault(tuple(sort(group)), Set[Index]()).add(idx)

Something like this might be simpler, could you try to refactor a bit to simplify?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also inline this logic.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the detailed feedback! I applied these changes and I think it definitely made the code significantly more concise and readable.

# Assign a new index name to each group of original indices.
new_idxs = {}
for k, (tensor_set, _) in enumerate(idx_groups):
new_idxs[tensor_set] = f"i{k}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use Namespace from Symbolic.py to create fresh index variable names.


corrected_bindings = {}
corrected_idx_lists = {}
for tns_name, supertensor, input_idx_list in inputs:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would fuse this list with the previous, they are the same loop, basically.

corrected_idx_lists = {}
for tns_name, supertensor, input_idx_list in inputs:
new_idx_list = []
mode_map = []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can construct a dictionary globally, which maps idx -> newidx, which I think would be helpful here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

new_idx_list = sort(list(set(global_idx_map[idx] for idx in access.idxs)))
mode_map = [[access.idxs.index(idx) for idx in idx_groups[new_idx] if idx in access.idxs] for new_idx in new_idx_list]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants